Deep Learning - Mushroom Identification¶

Disclaimer: ⚠️ Do not use this code to determine what to put in your omelette! 🍳

It’s important to note that the identification and classification of mushrooms should be done carefully, as some mushrooms can be toxic or deadly. Always consult with a mycologist or use a reliable field guide when identifying mushrooms

💡 📢 ☑️ remember to read the readme.md file for helpful hints on the best ways to view/navigate this project

If you visualize this notebook on github you will be missing important content

Some charts/diagrams/features are not visible in github. This is standard and well-known behaviour.

Consider viewing the pre-rendered HTML files, or run all notebooks end to end after enabling the feature flags that control long running operations:

If you chose to run this locally, there are some prerequisites:

  • you will need python 3.9
  • you will need to install the dependencies using pip install -r requirements.txt before proceeding.

Problem Statement (provided by Turing College)¶

Sprint 1: Computer Vision - Know Your Mushrooms

# Background

In US alone, around 7500 yearly cases of mushrooms poisoning are reported [(Source)](https://www.tandfonline.com/doi/full/10.1080/00275514.2018.1479561). According to the source, "misidentification of edible mushroom species appears to be the most common cause and may be preventable through education". To avoid expenses for hospitalization and in some cases pointless deaths, you have been hired by US National Health Service to create a machine learning model, that can recognize mushroom types. They want to install this on hand-held devices and to help people make the right choice when mushroom picking.

# Concepts to explore

Today, we will put everything we learned in this module and use it to solve a classification problem. The idea of this project is to use transfer learning on an architecture of your choice and fine-tune to predict mushroom types.

You will use this Kaggle dataset <https://www.kaggle.com/maysee/mushrooms-classification-common-genuss-images>

# How to start?

## Data

Well, the obvious first steps will be getting the data from Kaggle. There are a number of choices on how to do it, such as downloading images to your machine and then uploading to Drive or using [Kaggle API](https://github.com/Kaggle/kaggle-api). Once you get your data, start with an EDA, as this will directly feed into design choices for your architecture.

## Modeling

My suggestion is that you start with a simple pre-trained architecture, like ResNet18. This will allow you to fine-tune your net faster and if results are not too good, you can try switching to a larger model later. It is recommended that you use PyTorch Lightning or FastAI. Both are equally good for simple problems like this, but PyTorch Lightning will give you more control, better customization ability, and better understanding of your network.

# Requirements

- Choose whichever framework you prefer from FastAI, PyTorch Lightning or PyTorch.
- As always - EDA
- Use a pre-trained neural net as a backbone of your class
- Train a classifier. Don't forget to fine-tune
- Evaluate inference time
- Visualize results

# Evaluation Criteria
- Model performance
    - Classification performance
    - Inference speed
- EDA and documented findings
- Results analysis
- Code quality

# Bonus challenges
- Repeat the process with modifications to your network and see how the results vary.
    - Try a different optimizer
    - Add an intermediate layer between the backbone and output layer

# Sample correction questions

During a correction, you may get asked questions that test your understanding of covered topics.

- Describe how a convolutional layers works
- What is overfitting? Describe why is it bad/good and how to detect it?
- What is an optimizer? Describe in high level how it works.
- What are the advantages/disadvantages of transfer learning?

Getting Started¶

💡 After receiving feedback in past project that I tend to murder my reviewers by going too in depth, with excessive detail and explaining too many things, I will try to keep the descriptions for "what I'm thinking/considering" a bit more concise. Happy to hear your feedback.

Imports and initial setup¶

In [1]:
from IPython.display import display, Markdown, clear_output, HTML, IFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import itertools
import glob

import numpy as np
import pandas as pd
import seaborn as sns

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms, datasets, models
import torch.nn.functional as F
from torchmetrics.classification import (
    MulticlassConfusionMatrix,
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from torchmetrics import ConfusionMatrix

# import torchvision.models as models

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from optuna.integration import PyTorchLightningPruningCallback

from PIL import Image, ImageDraw, ImageFile
import albumentations as A
import plotly.express as px
import plotly.io as pio
from scipy import stats
from scipy.stats import chi2_contingency
import missingno as msno

from sklearn.metrics import ConfusionMatrixDisplay

from random import random, seed, shuffle
import logging
import warnings

import os
from os import path

from utils import *
from utils import __
loading utils modules... 
/home/edu/anaconda3/envs/py39_lab4/lib/python3.9/site-packages/xgboost/compat.py:93: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
  from pandas import MultiIndex, Int64Index
✅ completed
configuring autoreload... ✅ completed
In [2]:
from keras.utils import plot_model
from keras.applications.resnet50 import ResNet50
2023-12-10 09:43:11.636571: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 09:43:11.654057: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-10 09:43:11.654071: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-10 09:43:11.654700: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-10 09:43:11.657998: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-10 09:43:12.044040: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
In [3]:
print(f"{np.__version__    = }")
print(f"{pd.__version__    = }")
print(f"{torch.__version__ = }")
print(f"{pl.__version__    = }")
np.__version__    = '1.23.5'
pd.__version__    = '1.5.3'
torch.__version__ = '2.1.0+cu121'
pl.__version__    = '2.1.1'
In [4]:
seed(100)
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 50
pio.renderers.default = "png"
ImageFile.LOAD_TRUNCATED_IMAGES = True
util.check("done")
✅
In [5]:
%reload_ext mushrooms_utils

import mushrooms_utils as mushroom

Let's use black to auto-format all our cells so they adhere to PEP8

In [6]:
import lab_black

%reload_ext lab_black
util.patch_nb_black()
# fmt: off
# fmt: on
In [7]:
from sklearn import set_config

set_config(transform_output="pandas")
In [8]:
sns.set_theme(context="notebook", style="whitegrid")

moonstone = "#62b6cb"
moonstone_rgb = util.hex_to_rgb(moonstone)
moonstone_rgb_n = np.array(moonstone_rgb) / 255
In [9]:
logger = util.configure_logging(jupyterlab_level=logging.WARN, file_level=logging.DEBUG)

warnings.filterwarnings("ignore", category=FutureWarning)

# import warnings
# warnings.filterwarnings('error', category=pd.errors.DtypeWarning)
In [10]:
def ding(title="Ding!", message="Task completed"):
    """
    this method only works on linux
    """
    for i in range(2):
        !notify-send '{title}' '{message}'

Feature Toggles¶

Let's also create a simple feature toggle that we can use to skip expensive operations during notebook work (to save myself some time!)

Set it to true if you want to run absolutely everything. Set to false to skip optional steps/exploratory work.

In [11]:
def run_entire_notebook(filename: str = None):
    run_all = False
    if not run_all:
        print("skipping optional operation")
        fullpath = f"cached/printouts/{filename}.txt"
        if filename is not None and os.path.exists(fullpath):
            print("==== 🗃️ printing cached output ====")
            with open(fullpath) as f:
                print(f.read())
    return run_all

Fetching Data¶

In [12]:
kaggle_dataset_name = "maysee/mushrooms-classification-common-genuss-images"
db_filename = "Mushrooms"

auto_kaggle.download_dataset(kaggle_dataset_name, db_filename, timeout_seconds=3 * 60)
__
Kaggle API 1.5.13 - login as 'edualmas'
File [dataset/Mushrooms] already exists locally!
No need to re-download data for dataset [maysee/mushrooms-classification-common-genuss-images]

A quick check shows that the dataset, has 2 identical folders, with the data duplicated.

We will ignore one of the directories altogether

In [13]:
if os.path.exists("dataset/mushrooms/"):
    !rm -rf dataset/mushrooms/
    print("removed duplicated images")
else:
    print("folder with duplicated images has already been removed")
folder with duplicated images has already been removed

Splitting files into Train/Hp/Val/Test datasets¶

We need to split our dataset into a few chunks.

Since there seems to be no standard way out of the box to achieve all of these things:

  • do train test split of various chunks
  • create a ImageFolder that points to a specific dir
    • and to be able to split that
    • but at the same time, be able to apply different transforms to different chunks of the files *

We have to do the split ourselves.

* = The pytorch SubSets are basically a view of the original dataset and they all share the same transforms. This would be normal, but we want to apply different preprocessing to different splits:

  • train: basic img->tensor + resize/normalize + rebalancing + data augmentation with image transformations
  • hp: basic img->tensor + resize/normalize
  • val: basic img->tensor + resize/normalize
  • test: basic img->tensor + resize/normalize

So, here we go again, reinventing the wheel.

In [14]:
if not os.path.exists("dataset/Mushrooms_split/"):
    operation_log = split_utils.split_image_dataset(
        "dataset/Mushrooms",
        {"train": 0.65, "hp": 0.10, "val": 0.10, "test": 0.15},
        glob_filename_pattern="*.jpg",
    )
    display(operation_log.head())
    assert operation_log["source_file"].duplicated().any() == False
else:
    print("dataset is already split in [dataset/Mushrooms_split]")
dataset is already split in [dataset/Mushrooms_split]

PASS: each file has only been copied once ✅

No data leakage ✅

Let's also check that we lost no data:

In [15]:
with fs_utils.in_subdir("dataset/"):
    !find Mushrooms -type d -exec sh -c 'echo files: "$(find "{}" -type f | wc -l) \t {}"' \;
files: 6714 	 Mushrooms
files: 1073 	 Mushrooms/Boletus
files: 353 	 Mushrooms/Agaricus
files: 750 	 Mushrooms/Amanita
files: 1563 	 Mushrooms/Lactarius
files: 311 	 Mushrooms/Suillus
files: 836 	 Mushrooms/Cortinarius
files: 364 	 Mushrooms/Entoloma
files: 316 	 Mushrooms/Hygrocybe
files: 1148 	 Mushrooms/Russula
In [16]:
with fs_utils.in_subdir("dataset/"):
    !find Mushrooms_split -type d -exec sh -c 'echo files: "$(find "{}" -type f | wc -l) \t {}"' \;
files: 6714 	 Mushrooms_split
files: 4364 	 Mushrooms_split/train
files: 697 	 Mushrooms_split/train/Boletus
files: 230 	 Mushrooms_split/train/Agaricus
files: 488 	 Mushrooms_split/train/Amanita
files: 1016 	 Mushrooms_split/train/Lactarius
files: 202 	 Mushrooms_split/train/Suillus
files: 543 	 Mushrooms_split/train/Cortinarius
files: 237 	 Mushrooms_split/train/Entoloma
files: 205 	 Mushrooms_split/train/Hygrocybe
files: 746 	 Mushrooms_split/train/Russula
files: 1008 	 Mushrooms_split/test
files: 162 	 Mushrooms_split/test/Boletus
files: 53 	 Mushrooms_split/test/Agaricus
files: 112 	 Mushrooms_split/test/Amanita
files: 235 	 Mushrooms_split/test/Lactarius
files: 47 	 Mushrooms_split/test/Suillus
files: 125 	 Mushrooms_split/test/Cortinarius
files: 55 	 Mushrooms_split/test/Entoloma
files: 47 	 Mushrooms_split/test/Hygrocybe
files: 172 	 Mushrooms_split/test/Russula
files: 671 	 Mushrooms_split/val
files: 107 	 Mushrooms_split/val/Boletus
files: 35 	 Mushrooms_split/val/Agaricus
files: 75 	 Mushrooms_split/val/Amanita
files: 156 	 Mushrooms_split/val/Lactarius
files: 31 	 Mushrooms_split/val/Suillus
files: 84 	 Mushrooms_split/val/Cortinarius
files: 36 	 Mushrooms_split/val/Entoloma
files: 32 	 Mushrooms_split/val/Hygrocybe
files: 115 	 Mushrooms_split/val/Russula
files: 671 	 Mushrooms_split/hp
files: 107 	 Mushrooms_split/hp/Boletus
files: 35 	 Mushrooms_split/hp/Agaricus
files: 75 	 Mushrooms_split/hp/Amanita
files: 156 	 Mushrooms_split/hp/Lactarius
files: 31 	 Mushrooms_split/hp/Suillus
files: 84 	 Mushrooms_split/hp/Cortinarius
files: 36 	 Mushrooms_split/hp/Entoloma
files: 32 	 Mushrooms_split/hp/Hygrocybe
files: 115 	 Mushrooms_split/hp/Russula

Exploratory Data Analysis¶

Let's take a look at the raw dataset that we have retrieved from kaggle

In [17]:
import matplotlib.pyplot as plt
from PIL import Image
import os


@run
@cached_chart(extension="jpg")
def mushroom_pics_sample():
    # yes, "genera" would be a more scientifically accurate name...
    # but we're aiming for code that is easy to read by
    # a non-scientific audience (us, developers!)
    all_genuses = mushroom.Genus
    imgs_per_genus = 10

    f, ax = plt.subplots(len(all_genuses), imgs_per_genus, figsize=(15, 15))

    for g, genus in enumerate(all_genuses):
        image_dir = f"dataset/Mushrooms_split/train/{genus.value.dir_name}"
        images = (
            os.path.join(image_dir, file)
            for file in os.listdir(image_dir)
            if file.endswith(".jpg")
        )

        image_files = list(itertools.islice(images, 10))

        for i, image_file in enumerate(image_files):
            dirname = os.path.basename(os.path.dirname(image_file))
            filename = os.path.basename(image_file)
            image = Image.open(image_file)
            ax[g, i].imshow(image)
            ax[g, i].axis("off")
            ax[g, i].set_title(dirname + "/\n" + filename[:9] + "...")
    plt.tight_layout()
    plt.show()
    return f
Loading from cache [./cached/charts/mushroom_pics_sample.jpg]

A few thoughts:

  • 9 classes of mushrooms to classify. We could probably map each of those 9 into 2 groups (edible/dangerous), but it might defeat the purpose of the project, so we will leave it as a multi-class classification task.
  • Lots of images, with a wide range of resolutions/proportions.

Let's take a closer look at a few pics to understand how they are structured

In [18]:
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_1():
    f = plt.figure()
    pic1 = "dataset/Mushrooms_split/train/Agaricus/102_BV5Swi4Xfjc.jpg"
    image1 = Image.open(pic1)
    plt.imshow(image1)
    return f
Loading from cache [./cached/charts/mushroom_pics_sample_1.jpg]

We can understand a picture as a matrix of pixels.

We can also retrieve the 3 colour channels independently.

In [19]:
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_2():
    f = plt.figure()
    pic2 = "dataset/Mushrooms_split/train/Cortinarius/070_hnlGwobiKIs.jpg"

    image2 = Image.open(pic2)

    plt.imshow(image2.resize((280, 300)))
    charts.remove_axes()
    return f
    # SVD ~= PCA for removing noise/extra info from images
Loading from cache [./cached/charts/mushroom_pics_sample_2.jpg]
In [20]:
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_3():
    f, ax = plt.subplots(2, 3, figsize=(15, 9))
    pic2 = "dataset/Mushrooms_split/train/Cortinarius/070_hnlGwobiKIs.jpg"

    image2 = Image.open(pic2)

    ax[0, 0].set_title("red channel...")
    ax[0, 1].set_title("green channel...")
    ax[0, 2].set_title("blue channel..")
    ax[0, 0].imshow(np.array(image2)[:, :, 0], cmap="Reds_r")
    ax[0, 1].imshow(np.array(image2)[:, :, 1], cmap="Greens_r")
    ax[0, 2].imshow(np.array(image2)[:, :, 2], cmap="Blues_r")

    ax[1, 0].set_title("...in grayscale")
    ax[1, 1].set_title("...in grayscale")
    ax[1, 2].set_title("...in grayscale")
    ax[1, 0].imshow(np.array(image2)[:, :, 0], cmap="Greys_r")
    ax[1, 1].imshow(np.array(image2)[:, :, 1], cmap="Greys_r")
    ax[1, 2].imshow(np.array(image2)[:, :, 2], cmap="Greys_r")
    return f
Loading from cache [./cached/charts/mushroom_pics_sample_3.jpg]

Checking Classes¶

Let's take a look at the classes that we have and how are they distributed/shaped

In [21]:
classes = glob.glob("dataset/Mushrooms_split/train/*")
classes = {
    path.replace("dataset/Mushrooms_split/train/", "").lower(): path for path in classes
}
classes
Out[21]:
{'boletus': 'dataset/Mushrooms_split/train/Boletus',
 'agaricus': 'dataset/Mushrooms_split/train/Agaricus',
 'amanita': 'dataset/Mushrooms_split/train/Amanita',
 'lactarius': 'dataset/Mushrooms_split/train/Lactarius',
 'suillus': 'dataset/Mushrooms_split/train/Suillus',
 'cortinarius': 'dataset/Mushrooms_split/train/Cortinarius',
 'entoloma': 'dataset/Mushrooms_split/train/Entoloma',
 'hygrocybe': 'dataset/Mushrooms_split/train/Hygrocybe',
 'russula': 'dataset/Mushrooms_split/train/Russula'}
In [22]:
files = {f: c for c, path in classes.items() for f in glob.glob(f"{path}/*.jpg")}
df_files = pd.DataFrame({"path": files.keys(), "genus": files.values()})
df_files.sample(n=10)
Out[22]:
path genus
2764 dataset/Mushrooms_split/train/Cortinarius/105_... cortinarius
1120 dataset/Mushrooms_split/train/Amanita/079_6ZA8... amanita
2537 dataset/Mushrooms_split/train/Suillus/155_hPsV... suillus
4325 dataset/Mushrooms_split/train/Russula/617_B_F5... russula
3846 dataset/Mushrooms_split/train/Russula/036_7WKO... russula
2170 dataset/Mushrooms_split/train/Lactarius/410_gH... lactarius
1578 dataset/Mushrooms_split/train/Lactarius/249_4V... lactarius
4306 dataset/Mushrooms_split/train/Russula/302_Yq1G... russula
1772 dataset/Mushrooms_split/train/Lactarius/1079_D... lactarius
2509 dataset/Mushrooms_split/train/Suillus/097_JAf_... suillus
In [23]:
def get_picture_size(path):
    with Image.open(path) as img:
        width, height = img.size
    return width, height


def get_picture_height(path):
    return get_picture_size(path)[1]


def get_picture_width(path):
    return get_picture_size(path)[0]


df_files["height"] = df_files["path"].map(get_picture_height)
df_files["width"] = df_files["path"].map(get_picture_width)
df_files
Out[23]:
path genus height width
0 dataset/Mushrooms_split/train/Boletus/0201_PO7... boletus 693 960
1 dataset/Mushrooms_split/train/Boletus/0048_wGU... boletus 567 800
2 dataset/Mushrooms_split/train/Boletus/0112_s0d... boletus 600 800
3 dataset/Mushrooms_split/train/Boletus/0664_aAv... boletus 404 570
4 dataset/Mushrooms_split/train/Boletus/1089_w6Z... boletus 517 800
... ... ... ... ...
4359 dataset/Mushrooms_split/train/Russula/342_1iB1... russula 585 780
4360 dataset/Mushrooms_split/train/Russula/044_OC94... russula 600 800
4361 dataset/Mushrooms_split/train/Russula/361_IK7M... russula 533 800
4362 dataset/Mushrooms_split/train/Russula/170_1oaK... russula 535 800
4363 dataset/Mushrooms_split/train/Russula/209_9N--... russula 600 800

4364 rows × 4 columns

In [24]:
@run
@cached_chart()
def image_count_by_genus():
    f = plt.figure()
    order = df_files["genus"].value_counts().index
    sns.countplot(data=df_files, y="genus", order=order, color=moonstone)
    plt.title("count of raw images per genus")
    return f
Loading from cache [./cached/charts/image_count_by_genus.png]

Our aim is to have a few hundreds of examples per class, from the training split.

This means that:

  • some of the rare classes will have repeated images multiple times.
    • this won't be an issue because we will apply some random transformations to the images so they won't actually be "repeated" as much as "derived pics generated from the same original file"
  • some of the common classes will have some pics that won't be used.
    • if this were to be an issue (unlikely) we can easily fix this by increasing the number per class to 2000, 5000, etc... as needed.

For now we will continue with EDA and this resampling/rebalancing will be down later.

Understanding image resolutions¶

Let's take a look at the image resolutions and how they are distributed

In [25]:
resolutions = df_files.groupby("genus")["width", "height"].agg(["min", "max"])
resolutions
Out[25]:
width height
min max min max
genus
agaricus 423 1200 282 906
amanita 275 1280 183 1024
boletus 262 1200 192 948
cortinarius 259 1280 152 1024
entoloma 275 1210 183 921
hygrocybe 528 1200 370 931
lactarius 391 1280 280 1024
russula 400 1200 300 935
suillus 259 1280 184 961
In [26]:
print("   ", "width", "\t", "height")
print("min", df_files["width"].min(), "\t", df_files["height"].min())
print("max", df_files["width"].max(), "\t", df_files["height"].max())
    width 	 height
min 259 	 152
max 1280 	 1024
In [27]:
fig = px.density_heatmap(
    df_files,
    x="height",
    y="width",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=[
        (0, "white"),
        (0.01, "lightgrey"),
        (1, moonstone),
    ],
)

inner_ratio = df_files["width"].max() / df_files["height"].max()
marginal_ratio = 0.2
fig_width = 800
fig_height = fig_width / (inner_ratio + marginal_ratio)

fig.update_xaxes(range=[0, df_files["width"].max()])
fig.update_yaxes(range=[0, df_files["height"].max()])
fig.update_layout(
    plot_bgcolor="white",
    yaxis=dict(autorange="reversed"),
    autosize=False,
    width=fig_width,
    height=fig_height,
    title="heatmap of resolutions",
)

fig.show()
In [28]:
df_files[df_files["height"] > 1000]
Out[28]:
path genus height width
1400 dataset/Mushrooms_split/train/Amanita/382_HSyQ... amanita 1024 970
1502 dataset/Mushrooms_split/train/Lactarius/414_JR... lactarius 1024 1112
1511 dataset/Mushrooms_split/train/Lactarius/0279_D... lactarius 1024 768
1589 dataset/Mushrooms_split/train/Lactarius/418_kI... lactarius 1024 1056
1623 dataset/Mushrooms_split/train/Lactarius/0903_g... lactarius 1024 1273
2284 dataset/Mushrooms_split/train/Lactarius/0129_o... lactarius 1024 682
2742 dataset/Mushrooms_split/train/Cortinarius/070_... cortinarius 1024 940
In [29]:
df_files["w_group"] = df_files["width"] // 100
df_files["h_group"] = df_files["height"] // 100

Let's cluster images in buckets of 100x100 pixels

In [30]:
@run
@cached_chart()
def resolution_clusters():
    heatmap_scaling = 0.6
    f = plt.figure(figsize=(12 * heatmap_scaling, 10 * heatmap_scaling))

    resolution_groups = (
        df_files[["w_group", "h_group", "path"]]
        .pivot_table(index="h_group", columns="w_group", aggfunc="count")
        .droplevel(0, axis=1)
    )

    sns.heatmap(resolution_groups, cmap="Blues", annot=True, fmt=".0f")
    plt.title("resolution clusters")
    plt.xlabel("width x 100")
    plt.ylabel("height x 100")
    return f
Loading from cache [./cached/charts/resolution_clusters.png]

It seems that most images are around the SVGA range 800x800 (or ~ 800x600). These are a good starting point: enough info to discern figures, but not unnecessarily large.

This project requires us to use transfer learning (by starting from a pre-trained model), so we might have to scale them somewhat, but at least they don't seem to suffer from major issues (not too large/small).

Splitting train/test dataset¶

In [31]:
def valid_image_file(filename: str) -> bool:
    # ImageFolder() requires the signature to be:
    # Callable[[str], bool]

    valid_last_bytes = {
        ".jpg": b"\xff\xd9",
        ".png": b"\x60\x82",
        # ".gif": b"\x3b",
    }

    extension = filename[-4:]
    if extension not in valid_last_bytes.keys():
        raise ValueError(f"File extension is unknown. {extension = }")

    with open(filename, mode="rb") as image_file:
        file_content = image_file.read()
        last_bytes = file_content[-2:]
    return last_bytes == valid_last_bytes[extension]
In [32]:
just_resize = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # resnet50 requires this!
        transforms.ToTensor(),
    ]
)
In [33]:
raw_dataset = datasets.ImageFolder(
    "dataset/Mushrooms_split/train/",
    transform=just_resize,
    is_valid_file=valid_image_file,
)

Calculating image metrics¶

Some of the steps and transformations require us to normalize our images.

We want to calculate some metadata/metrics from our training split so we can apply some transformations/normalization that we hope will help our model.

In [34]:
def calculate_mean_std_for_split(dloader):
    mean = 0.0
    std = 0.0
    nb_samples = 0.0

    # Calculate the mean and std on the training set only
    for data, _ in dloader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples

    print(f"Mean: {mean}")
    print(f"Std: {std}")
    return mean, std
In [35]:
train_mean, train_std = calculate_mean_std_for_split(DataLoader(raw_dataset))
Mean: tensor([0.3905, 0.3685, 0.2805])
Std: tensor([0.2296, 0.2095, 0.2032])

We have the mean/std for the raw images in the train split. We can use this to normalize our data for all 4 dataset.

Creating the final versions of our dataloaders¶

Now that we know the mean/std values for our training images, we can create the data loaders for all of the splits, and we can configure them correctly.

Remember:

  • train: image as tensors + resize + normalize + data augmentation techniques
  • hp: image as tensors + resize + normalize (based on training data)
  • val: image as tensors + resize + normalize (based on training data)
  • test: image as tensors + resize + normalize (based on training data)
In [36]:
normalize = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist()),
    ]
)

image_augmentation = transforms.Compose(
    [
        transforms.RandomResizedCrop(800),
        transforms.Resize((400, 400)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.RandomAffine(degrees=100),
        transforms.GaussianBlur(3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist()),
    ]
)
In [37]:
train_dataset = datasets.ImageFolder(
    "dataset/Mushrooms_split/train/",
    transform=image_augmentation,
    is_valid_file=valid_image_file,
)
In [38]:
hp_dataset = datasets.ImageFolder(
    "dataset/Mushrooms_split/hp/",
    transform=normalize,
    is_valid_file=valid_image_file,
)
In [39]:
val_dataset = datasets.ImageFolder(
    "dataset/Mushrooms_split/val/",
    transform=normalize,
    is_valid_file=valid_image_file,
)
In [40]:
test_dataset = datasets.ImageFolder(
    "dataset/Mushrooms_split/test/",
    transform=normalize,
    is_valid_file=valid_image_file,
)
In [41]:
# using 15 workers; the value suggested by pytorch in its warning
data_load_settings = {"batch_size": 128, "num_workers": 15}

train_loader = DataLoader(
    train_dataset, pin_memory=True, shuffle=True, **data_load_settings
)
hp_loader = DataLoader(hp_dataset, shuffle=False, **data_load_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **data_load_settings)
test_loader = DataLoader(test_dataset, shuffle=False, **data_load_settings)
In [42]:
train_dataset
Out[42]:
Dataset ImageFolder
    Number of datapoints: 4363
    Root location: dataset/Mushrooms_split/train/
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(800, 800), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
               Resize(size=(400, 400), interpolation=bilinear, max_size=None, antialias=warn)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               RandomAffine(degrees=[0.0, 0.0], translate=(0.1, 0.1))
               RandomAffine(degrees=[-100.0, 100.0])
               GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.3905445635318756, 0.3684823215007782, 0.2804551124572754], std=[0.22958891093730927, 0.20953883230686188, 0.2031669169664383])
           )

Let's inspect that the images have been transformed as expected:

For Training:

  • ✅ normalization
  • ✅ cropping
  • ✅ rotations
  • ✅ resizing
In [43]:
images, labels = next(iter(train_loader))

images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))

for i, image in enumerate(images):
    cax = ax[i // 8, i % 8]
    image = np.transpose(image, (1, 2, 0))
    # image = (image * train_std) + train_mean
    image = np.clip(image, 0, 1)

    cax.imshow(image)

plt.show()

For Training:

  • ✅ normalization
  • ❌ no cropping
  • ❌ no rotations
  • ✅ resizing
In [44]:
images, labels = next(iter(hp_loader))

images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))

for i, image in enumerate(images):
    cax = ax[i // 8, i % 8]
    image = np.transpose(image, (1, 2, 0))
    # image = (image * train_std) + train_mean
    image = np.clip(image, 0, 1)

    cax.imshow(image)

plt.show()
In [45]:
images, labels = next(iter(val_loader))

images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))

for i, image in enumerate(images):
    cax = ax[i // 8, i % 8]
    image = np.transpose(image, (1, 2, 0))
    # image = (image * train_std) + train_mean
    image = np.clip(image, 0, 1)

    cax.imshow(image)

plt.show()
In [46]:
images, labels = next(iter(test_loader))
images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))

for i, image in enumerate(images):
    cax = ax[i // 8, i % 8]
    image = np.transpose(image, (1, 2, 0))
    # image = (image * train_std) + train_mean
    image = np.clip(image, 0, 1)

    cax.imshow(image)

plt.show()

It seems that at least one of the images is not a valid example of a mushroom.. The issue is that it will be likely expensive to scan them all and remove the pictures that, while valid image files, are not actually mushrooms.

Something we could do is flag any images that our model predicts with low confidence, to be manually inspected later, since manual inspection of 6000+ pictures is not scalable nor desirable, in this context.

Rebalancing our training dataset¶

All our dataloaders have been configured to normalize images based on the features in the training split ✅

In [47]:
def rebalance(loader, num_samples_per_class=1000):
    train_dataset = loader.dataset
    targets = [label for label in train_loader.dataset.targets]
    class_count = torch.bincount(torch.tensor(targets))
    class_weights = 1.0 / class_count.float()
    print(class_weights)
    weights = class_weights[torch.tensor(targets)]
    print(len(weights), weights)

    num_samples = num_samples_per_class * len(class_count)

    rebalanced_sampler = WeightedRandomSampler(weights, num_samples)

    rebalanced_dataloader = DataLoader(
        dataset=loader.dataset,
        batch_size=loader.batch_size,
        sampler=rebalanced_sampler,
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )

    return rebalanced_dataloader
In [48]:
balanced_train_loader = rebalance(train_loader, num_samples_per_class=3000)
del train_loader
tensor([0.0043, 0.0020, 0.0014, 0.0018, 0.0042, 0.0049, 0.0010, 0.0013, 0.0050])
4363 tensor([0.0043, 0.0043, 0.0043,  ..., 0.0050, 0.0050, 0.0050])
In [49]:
@run
@cached_chart()
def plot_classes_after_rebalance():
    all_labels = []
    for _, y in balanced_train_loader:
        all_labels.append(y.flatten())

    all_labels = torch.cat(all_labels)
    class_counts = torch.bincount(all_labels)
    print(class_counts)
    return sns.countplot(x=all_labels.numpy(), color=moonstone)
Loading from cache [./cached/charts/plot_classes_after_rebalance.png]

Much better than the original distribution of classes

Reviewing the work done, so far¶

  • ✅ taken a look at the downloaded dataset
  • ✅ identified and skipped duplicated folders
  • ✅ explored our data to better understand how it is structured (resolution, color layers, class imbalance, etc.. )
  • ✅ validated our data
  • ✅ removed the few images files that seem damaged/corrupted/malformed
  • ✅ split our data into different datasets
  • ✅ configured the train split to be balanced so that rare classes have better representation
  • ✅ configured image transformation techniques to use for data augmentation
  • ✅ configured all other data loaders to use basic transformations (no data augmentation)

Creating our CNN¶

Let's create our Convolutional Neural Network to classify our data samples.

As required, we will use a pretrained model.

Let's try to visualize the structure of this pretrained network, as it is out of the box (without customizations):

(This diagram is generated by keras, from a pre-known model of the same CNN. Good enough for a peek preview, but we will not be able to do the same with our model later).

Understanding our pretrained model¶

In [50]:
if not os.path.exists("cached/resnet50_diagram.png"):
    model = ResNet50(weights="imagenet")
    plot_model(model, to_file="cached/resnet50_diagram.png", show_shapes=True, dpi=96)
In [51]:
Image.open("cached/resnet50_diagram.png")
Out[51]:

The important bit is at the very end. The last 4 or 5 blocks. We will remove the 1000 neuron output and put our 9 neuron output

In [52]:
class MushroomClassifier(pl.LightningModule):
    def __init__(self, mushroom_classes=9, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        super(MushroomClassifier, self).__init__()
        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in self.model.parameters():
            param.requires_grad = False
        cnn_codes = self.model.fc.in_features
        self.model.fc = nn.Linear(cnn_codes, mushroom_classes)

        self.m_acc = MulticlassAccuracy(num_classes=mushroom_classes)
        self.m_prec = MulticlassPrecision(num_classes=mushroom_classes)
        self.m_recall = MulticlassRecall(num_classes=mushroom_classes)
        self.m_f1 = MulticlassF1Score(num_classes=mushroom_classes)
        self.m_cm = MulticlassConfusionMatrix(num_classes=mushroom_classes)

        self.lr = lr
        self.betas = betas
        self.eps = eps

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = nn.CrossEntropyLoss()(y_pred, y)
        self.log("train_loss", loss)
        self.log("train_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
        self.log("train_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
        self.log("train_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
        self.log("train_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = nn.CrossEntropyLoss()(y_pred, y)
        self.log("val_loss", loss)
        self.log("val_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
        self.log("val_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
        self.log("val_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
        self.log("val_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        predicted = self.forward(x)
        predicted = torch.argmax(predicted, 1)
        return predicted

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = nn.CrossEntropyLoss()(y_pred, y)
        self.log("test_loss", loss)
        self.log("test_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
        self.log("test_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
        self.log("test_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
        self.log("test_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # RMSprop(self.parameters(), lr=self.lr)
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
        )

        return optimizer

A few things to notice:

  • We are using resnet50 (pretrained) as the backbone of our model
  • We have added an output layer with 9 neurons
  • We have created a few metrics so we can track/plot them later.
    • since we are using pytorch lightning, it will automatically calculate the right values at the end of each batch/epoch.
  • The model has parameters for learning rate and decays that will be used later during hyperparam tuning.

Training Configuration¶

In addition to creating this model, we also want our PT trainer to work optimally:

  • It should support early stopping
  • It should store snapshots of the best models as they are trained, so we can use them later.

Early Stopping¶

In [53]:
def new_early_stopping_callback(
    metric_to_monitor="val_loss",
    min_change_to_consider_an_improvement=0.00,
    stop_after_x_epochs_without_improvement=3,
):
    return EarlyStopping(
        monitor=metric_to_monitor,
        min_delta=min_change_to_consider_an_improvement,
        patience=stop_after_x_epochs_without_improvement,
        verbose=False,
        mode="min",
    )

Saving regular checkpoints¶

In [54]:
def new_checkpoint_callback(
    metric_to_monitor="val_loss",
    mode="min",
    filename="checkpoint_resnet50",
    save_top_k=10,
):
    return ModelCheckpoint(
        dirpath="models/training/",
        save_top_k=save_top_k,
        verbose=True,
        auto_insert_metric_name=True,
        monitor=metric_to_monitor,
        mode=mode,
        filename=filename,
    )
In [55]:
mushroom_classifier_model = MushroomClassifier(mushroom_classes=9, lr=0.0001)
In [56]:
if run_entire_notebook():
    callbacks = [new_checkpoint_callback(), new_early_stopping_callback()]
    trainer = Trainer(
        max_epochs=2,
        callbacks=callbacks,
    )
    trainer.fit(
        mushroom_classifier_model,
        train_dataloaders=balanced_train_loader,
        val_dataloaders=val_loader,
    )
skipping optional operation

Let's take a look at a few interesting things:

18.4 K    Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.106    Total estimated model params size (MB)

These numbers seem to match our expectation

  • We are using transfer learning and we only want our last layer to train and learn during backpropagation
    • thus, the "relatively small" number of trainable parameters
    • and the large number of non trainable parameters
  • Let's try to check if the value 18.4K makes sense using simple math:
    • We know that we are only training the weights and biases of the last layer.
    • The calculation for trainable parameters should only include all weights and biases for the last layer

The formula to calculate this is:

$params = (N*\text{inputNeurons} + 1*\text{Bias}) * \text{OutputNeurons}$

So:

$18.4K \approx (\text{Inp} + 1) * 9$

We expect the number of neurons on the previous layer to be something around:

${{18.4k} \over {9}} - 1 \approx \text{Inp} \approx 2043.44\overline{444}$

We can already see that this is very close to a base 2 number (2048), which is likely to be the size of the previous layer!)

Let's extract the number of neurons in the layer previous to the final output layer (the last layer that feeds into our 9 outputs) just to verify it fully:

This is the "out of the box model", this is why it shows 1000 output neurons.

To get the number of neurons before our custom last layer, we can query it directly:

In [57]:
mushroom_classifier_model.model.fc.in_features
Out[57]:
2048

Or we can use torchsummary to inspect the entire CNN:

In [58]:
from torchsummary import summary
In [59]:
summary(mushroom_classifier_model.cuda(), input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
       Bottleneck-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
      BatchNorm2d-24          [-1, 256, 56, 56]             512
             ReLU-25          [-1, 256, 56, 56]               0
       Bottleneck-26          [-1, 256, 56, 56]               0
           Conv2d-27           [-1, 64, 56, 56]          16,384
      BatchNorm2d-28           [-1, 64, 56, 56]             128
             ReLU-29           [-1, 64, 56, 56]               0
           Conv2d-30           [-1, 64, 56, 56]          36,864
      BatchNorm2d-31           [-1, 64, 56, 56]             128
             ReLU-32           [-1, 64, 56, 56]               0
           Conv2d-33          [-1, 256, 56, 56]          16,384
      BatchNorm2d-34          [-1, 256, 56, 56]             512
             ReLU-35          [-1, 256, 56, 56]               0
       Bottleneck-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          32,768
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
       Bottleneck-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
       Bottleneck-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
       Bottleneck-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
       Bottleneck-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 28, 28]         131,072
      BatchNorm2d-80          [-1, 256, 28, 28]             512
             ReLU-81          [-1, 256, 28, 28]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
       Bottleneck-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
      Bottleneck-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
      Bottleneck-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
      Bottleneck-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
      Bottleneck-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 512, 14, 14]         524,288
     BatchNorm2d-142          [-1, 512, 14, 14]           1,024
            ReLU-143          [-1, 512, 14, 14]               0
          Conv2d-144            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-145            [-1, 512, 7, 7]           1,024
            ReLU-146            [-1, 512, 7, 7]               0
          Conv2d-147           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-148           [-1, 2048, 7, 7]           4,096
          Conv2d-149           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-150           [-1, 2048, 7, 7]           4,096
            ReLU-151           [-1, 2048, 7, 7]               0
      Bottleneck-152           [-1, 2048, 7, 7]               0
          Conv2d-153            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-154            [-1, 512, 7, 7]           1,024
            ReLU-155            [-1, 512, 7, 7]               0
          Conv2d-156            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-157            [-1, 512, 7, 7]           1,024
            ReLU-158            [-1, 512, 7, 7]               0
          Conv2d-159           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
      Bottleneck-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-170           [-1, 2048, 7, 7]           4,096
            ReLU-171           [-1, 2048, 7, 7]               0
      Bottleneck-172           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                    [-1, 9]          18,441
          ResNet-175                    [-1, 9]               0
================================================================
Total params: 23,526,473
Trainable params: 18,441
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.55
Params size (MB): 89.75
Estimated Total Size (MB): 376.87
----------------------------------------------------------------

The important bit is at the very end:

            ReLU-171           [-1, 2048, 7, 7]               0
      Bottleneck-172           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-173          [-1, 2048, 1, 1]               0  <<< these two are the 
          Linear-174                    [-1, 9]          18,441  <<< important bits
          ResNet-175                    [-1, 9]               0
================================================================

Sure enough: $(2048 + 1) * 9 \approx 18,4k $ that we saw before:

In [60]:
result = (2048 + 1) * 9
print(result)
util.check(result == 18441)
assert result == 18441, "# of parameters does not match our calculations"
18441
✅

Hyperparameter tuning¶

Let's use optuna to perform some hyperparameter tuning.

For CNN's, one of the most critical hp to tune is the learning rate.

Since we want to see how things work under the hood, let's also tune a couple of extra parameters related to decay beta1 and beta2.

In [61]:
def optimize_hyperparams(study_name: str):
    def objective(trial):
        print("providing new parameters from optuna")
        lr = trial.suggest_loguniform("lr", 1e-5, 1e-1)
        beta1 = trial.suggest_uniform("beta1", 0.8, 1.0)
        beta2 = trial.suggest_uniform("beta2", 0.9, 1.0)

        newly_created_model = MushroomClassifier(
            mushroom_classes=9, lr=lr, betas=(beta1, beta2)
        )

        trainer = Trainer(
            max_epochs=60,
            callbacks=[
                new_checkpoint_callback(),
                new_early_stopping_callback(stop_after_x_epochs_without_improvement=6),
                PyTorchLightningPruningCallback(trial, monitor="val_loss"),
            ],
        )

        trainer.fit(
            newly_created_model,
            train_dataloaders=balanced_train_loader,  # hp_loader
            val_dataloaders=val_loader,
        )

        return trainer.callback_metrics["val_loss"].item()

    optuna_utils.create_optuna_study(study_name, allow_resume=True, direction="minimize")
    study = optuna_utils.get_study(study_name).optimize(objective, n_trials=100)
In [62]:
if run_entire_notebook():
    optimize_hyperparams(study_name="mushroom_training_val")
    # optimize_hyperparams(study_name="mushroom_hp_val")
skipping optional operation
In [63]:
@run
@cached_chart()
def loss_best_trial():
    studies = ["mushroom_training_val", "mushroom_hp_val"]

    f, ax = plt.subplots(1, 2, figsize=(12, 5))
    f.suptitle("comparison of hp tuning depending on dataset size")
    for i, study_name in enumerate(studies):
        cax = ax[i]
        cax.set_title(f"loss for best trial {study_name}")
        cax.set_xlabel("epoch")
        cax.set_ylabel("loss")
        study = optuna_utils.get_study(study_name)
        l = sns.lineplot(study.best_trial.intermediate_values, ax=cax, color=moonstone)
        lr = study.best_params["lr"]
        beta1 = study.best_params["beta1"]
        beta2 = study.best_params["beta2"]

        props = dict(boxstyle="round", facecolor="grey", alpha=0.15)  # bbox features
        cax.text(
            1.03,
            0.98,
            f"{lr = :.4f}\n{beta1 = :.4f}\n{beta2 = :.4f}",
            transform=cax.transAxes,
            fontsize=12,
            verticalalignment="top",
            bbox=props,
        )

        cax.set_ylim(0, 2)
        cax.set_xlim(0, 60)

    plt.tight_layout()
    return l
Loading from cache [./cached/charts/loss_best_trial.png]

We can see that a larger dataset size contributes to our model learning better and faster, during hyperparameter tuning.

The decay parameters beta1/beta2 also look much better for the first larger dataset, (beta2 being a lot closer to 1 than beta1) instead of the similar values that it has in the second study

Comparing all trials¶

Since the better tuned model benefits from the larger training dataset, we will only compare performance for that one study, and ignore the rest of optuna studies.

In [64]:
def loss_all_trials(study_name: str, ids=None):
    plt.xlabel("epoch")
    plt.ylabel("loss")
    study = optuna_utils.get_study(study_name)
    best_trial = study.best_trial.number
    for trial in study.trials:
        col = moonstone if trial.number == best_trial else "grey"
        lw = 3 if trial.number == best_trial else 0.3
        label = "Best Trial" if trial.number == best_trial else None
        if not ids:
            plt.title(f"loss for all trials [{study_name}]")
            sns.lineplot(trial.intermediate_values, color=col, linewidth=lw, label=label)
        elif trial.number in ids:
            plt.title(f"loss for trials [{study_name}] {ids}")
            sns.lineplot(
                trial.intermediate_values,
                label=f"trial {trial.number}",
                color=col,
                linewidth=lw,
            )
    plt.legend()
    plt.ylim(0, 2.1)
    return plt.gca()
In [65]:
@run
@cached_chart()
def loss_all_trials_using_training_split():
    return loss_all_trials("mushroom_training_val")
Loading from cache [./cached/charts/loss_all_trials_using_training_split.png]

Comparing the best performers (large dataset)¶

In [66]:
optuna_study = optuna_utils.get_study(study_name="mushroom_training_val")
studies_df = optuna_study.trials_dataframe()
studies_df = studies_df[studies_df["state"] == "COMPLETE"]
longest = studies_df.sort_values("duration", ascending=False)[:5].number
studies_of_interest = set(longest) | set([optuna_study.best_trial.number])
studies_df.loc[list(studies_of_interest)]
Out[66]:
number value datetime_start datetime_complete duration params_beta1 params_beta2 params_lr state
1 1 0.787550 2023-12-10 03:08:46.818755 2023-12-10 03:55:18.139670 0 days 00:46:31.320915 0.873872 0.967798 0.000377 COMPLETE
2 2 1.119920 2023-12-10 03:55:18.157443 2023-12-10 05:03:13.707603 0 days 01:07:55.550160 0.883608 0.982735 0.000024 COMPLETE
4 4 0.769809 2023-12-10 05:21:20.399222 2023-12-10 06:23:34.361549 0 days 01:02:13.962327 0.841713 0.966169 0.000302 COMPLETE
6 6 0.779809 2023-12-10 06:41:43.628434 2023-12-10 07:13:26.086821 0 days 00:31:42.458387 0.853146 0.910811 0.000900 COMPLETE
24 24 0.815131 2023-12-10 08:34:20.263166 2023-12-10 09:03:57.506501 0 days 00:29:37.243335 0.833433 0.955447 0.001120 COMPLETE
In [67]:
loss_all_trials("mushroom_training_val", studies_of_interest)
Out[67]:
<AxesSubplot: title={'center': 'loss for trials [mushroom_training_val] {1, 2, 4, 6, 24}'}, xlabel='epoch', ylabel='loss'>

Comparing HP tuning when using a smaller dataset¶

A second attempt was used to use a smaller dataset split for HP tuning, with the hopes of finding optimal values faster.

However, this has resulted in noticeably worse performance. None of the trials achieve anything close to the ~0.7/0.8 loss on unseen data.

In [68]:
loss_all_trials("mushroom_hp_val")
Out[68]:
<AxesSubplot: title={'center': 'loss for all trials [mushroom_hp_val]'}, xlabel='epoch', ylabel='loss'>

Training our best model¶

In [69]:
study = optuna_utils.get_study(study_name="mushroom_training_val")

print("loss\t", study.best_value)
print("params\t", study.best_params)
loss	 0.7698094248771667
params	 {'lr': 0.0003022396228963802, 'beta1': 0.8417125470622357, 'beta2': 0.966169418528305}
In [70]:
best_model = MushroomClassifier(
    mushroom_classes=9,
    lr=study.best_params["lr"],
    betas=(study.best_params["beta1"], study.best_params["beta2"]),
)
In [71]:
trainer = Trainer(
    max_epochs=60,
    callbacks=[
        new_checkpoint_callback(),
        new_early_stopping_callback(stop_after_x_epochs_without_improvement=6),
    ],
)

trainer.fit(
    best_model,
    train_dataloaders=balanced_train_loader,  # hp_loader
    val_dataloaders=val_loader,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/edu/anaconda3/envs/py39_lab4/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:634: UserWarning:

Checkpoint directory /home/edu/turing/projects/sprint13-mushrooms/project/models/training exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                      | Params
-------------------------------------------------------
0 | model    | ResNet                    | 23.5 M
1 | m_acc    | MulticlassAccuracy        | 0     
2 | m_prec   | MulticlassPrecision       | 0     
3 | m_recall | MulticlassRecall          | 0     
4 | m_f1     | MulticlassF1Score         | 0     
5 | m_cm     | MulticlassConfusionMatrix | 0     
-------------------------------------------------------
18.4 K    Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.106    Total estimated model params size (MB)
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Training: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 0, global step 211: 'val_loss' reached 1.59601 (best 1.59601), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 1, global step 422: 'val_loss' reached 1.36843 (best 1.36843), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 2, global step 633: 'val_loss' reached 1.20773 (best 1.20773), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 3, global step 844: 'val_loss' reached 1.14454 (best 1.14454), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 4, global step 1055: 'val_loss' reached 1.10301 (best 1.10301), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 5, global step 1266: 'val_loss' reached 1.05823 (best 1.05823), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 6, global step 1477: 'val_loss' reached 1.03442 (best 1.03442), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 7, global step 1688: 'val_loss' reached 0.99475 (best 0.99475), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 8, global step 1899: 'val_loss' reached 0.95910 (best 0.95910), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 9, global step 2110: 'val_loss' reached 0.94587 (best 0.94587), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 10, global step 2321: 'val_loss' reached 0.93859 (best 0.93859), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 11, global step 2532: 'val_loss' reached 0.93307 (best 0.93307), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 12, global step 2743: 'val_loss' reached 0.90248 (best 0.90248), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 13, global step 2954: 'val_loss' reached 0.88191 (best 0.88191), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 14, global step 3165: 'val_loss' reached 0.86941 (best 0.86941), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 15, global step 3376: 'val_loss' reached 0.87006 (best 0.86941), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 16, global step 3587: 'val_loss' reached 0.85080 (best 0.85080), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 17, global step 3798: 'val_loss' reached 0.84879 (best 0.84879), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 18, global step 4009: 'val_loss' reached 0.86033 (best 0.84879), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 19, global step 4220: 'val_loss' reached 0.83894 (best 0.83894), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 20, global step 4431: 'val_loss' reached 0.83431 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 21, global step 4642: 'val_loss' reached 0.84595 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 22, global step 4853: 'val_loss' reached 0.83944 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 23, global step 5064: 'val_loss' reached 0.83701 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 24, global step 5275: 'val_loss' reached 0.83139 (best 0.83139), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 25, global step 5486: 'val_loss' reached 0.82635 (best 0.82635), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 26, global step 5697: 'val_loss' reached 0.81184 (best 0.81184), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 27, global step 5908: 'val_loss' reached 0.80840 (best 0.80840), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 28, global step 6119: 'val_loss' reached 0.80679 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 29, global step 6330: 'val_loss' reached 0.81115 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 30, global step 6541: 'val_loss' reached 0.81302 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 31, global step 6752: 'val_loss' reached 0.81441 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 32, global step 6963: 'val_loss' reached 0.79239 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 33, global step 7174: 'val_loss' reached 0.80597 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 34, global step 7385: 'val_loss' reached 0.81312 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 35, global step 7596: 'val_loss' reached 0.79496 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 36, global step 7807: 'val_loss' reached 0.77536 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 37, global step 8018: 'val_loss' reached 0.79161 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 38, global step 8229: 'val_loss' reached 0.78940 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 39, global step 8440: 'val_loss' reached 0.78410 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 40, global step 8651: 'val_loss' reached 0.77464 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 41, global step 8862: 'val_loss' reached 0.79397 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 42, global step 9073: 'val_loss' reached 0.77888 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 43, global step 9284: 'val_loss' reached 0.77926 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 44, global step 9495: 'val_loss' reached 0.78799 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 45, global step 9706: 'val_loss' reached 0.78845 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: |          | 0/? [00:00<?, ?it/s]
Epoch 46, global step 9917: 'val_loss' reached 0.79021 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10

Assessing performance¶

While training our model, we are comfortably achieving ~0.70 loss on unseen data (val split).

💡 Since we are using PyTorch Lightning, we don't need to manually do model.train/eval() as this is done automatically by the framework

Let's check see how well our model performs on our final test split:

In [72]:
trainer.test(best_model, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing: |          | 0/? [00:00<?, ?it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    test_accuracy_epoch    │    0.15733768045902252    │
│       test_f1_epoch       │    0.1711704581975937     │
│         test_loss         │    0.8650455474853516     │
│   test_precision_epoch    │    0.19689351320266724    │
│     test_recall_epoch     │    0.15733768045902252    │
└───────────────────────────┴───────────────────────────┘
Out[72]:
[{'test_loss': 0.8650455474853516,
  'test_accuracy_epoch': 0.15733768045902252,
  'test_precision_epoch': 0.19689351320266724,
  'test_recall_epoch': 0.15733768045902252,
  'test_f1_epoch': 0.1711704581975937}]

⚠️ remember that these are epoch based and not the final results. The only one we can take into consideration right now is test_loss

✅ We will look at the overall performance further down.

Also, it does not show signs of overfitting ✅ seeing that the train/val loss was 0.80 and this on test (unseen data) is 0.86

Checking out TensorBoard to comparing and visualize performance¶

We can use tensor_board to visualize and compare the different runs of our CNN

In [73]:
predicted_test = trainer.predict(best_model, test_loader)
predicted_test = torch.cat(predicted_test)
predicted_test
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting: |          | 0/? [00:00<?, ?it/s]
Out[73]:
tensor([0, 1, 0,  ..., 8, 8, 8])
In [77]:
actual = [y for x, y in test_loader.dataset]
In [78]:
assert len(predicted_test) == len(actual)
In [79]:
print(f"{len(predicted_test) = }")
len(predicted_test) = 1007
In [80]:
cm = MulticlassConfusionMatrix(num_classes=9, normalize="true")
confusion_matrix = cm(predicted_test, torch.tensor(actual))
In [81]:
sns.heatmap(
    confusion_matrix,
    annot=True,
    fmt="0.00%",
    cmap=sns.light_palette(moonstone, as_cmap=False, n_colors=10),
)
Out[81]:
<AxesSubplot: >

This is clearly terrible. Almost all trial perform poorly out of the box.

None of them get even close to the performance we got when we did HP tuning with the larger dataset.

Evaluating model performance¶

Model prediction performance¶

In [98]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


def calculate_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")
    f1 = f1_score(y_true, y_pred, average="weighted")

    print(f"Accuracy: {accuracy:.3f}")
    print(f"Precision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"F1 Score: {f1:.3f}")


calculate_metrics(predicted_test, torch.tensor(actual))
Accuracy: 0.694
Precision: 0.713
Recall: 0.694
F1 Score: 0.692

Model inference speed¶

It seems that PyTorch is leveraging the GPU properly during training

Let's see how this translates to our training/inference speed

🟩 GPU Training speed¶

The training speed for our model is around 14 items/sec

🟩 GPU Inference speed¶

The training speed for our model is around 14 items/sec

🟧 CPU Inference Speed¶

Let's compare how this compares with inference on CPU.

Since we just want to get a rough idea about performance difference and orders of magnitude.

In [ ]:
@run
def inference_on_cpu():
    cpu_model = best_model.to("cpu")
    cpu_trainer = pl.Trainer(accelerator="cpu")
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True
    )
    cpu_trainer.test(cpu_model, test_loader)

When using CPU, our inference speed drops down to almost 1 item/second (instead of 14/sec)

🟥 Speed using a TPU device¶

While this was not required, I wanted to try to run this model on an old Coral TPU I have laying around.

  • Step 1: convert our Pytorch Lightning model to ONNX, and store
  • Step 2: convert ONNX model to TensorFlow Lite
  • Step 3: run the tf model on the TPU

Despite the effort, I could not get past step 2. The conversion between ONNX and TF was failing consistently.

I am not worried about this, since it is beyond the scope of the requirements, but if you have experience with these devices/frameworks/standards, I'd love to get some of your insights and advice. :)

In [88]:
import torch
from pytorch_lightning import LightningModule


def convert_to_onnx(model: LightningModule, input_sample: torch.Tensor, onnx_path: str):
    model.to_onnx(onnx_path, input_sample)
In [89]:
convert_to_onnx(
    best_model, input_sample=torch.randn(1, 3, 224, 224), onnx_path="best_model.onnx"
)
In [95]:
if run_entire_notebook("onnx"):
    !onnx-tf convert -i best_model.onnx -o best_model.pb
skipping optional operation
==== 🗃️ printing cached output ====
2023-12-10 10:51:49.540305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 10:51:49.558228: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-10 10:51:49.558244: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-10 10:51:49.558859: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-10 10:51:49.561853: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-10 10:51:49.866110: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/edu/anaconda3/envs/py39_lab4/lib/python3.9/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: 

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

  warnings.warn(
2023-12-10 10:51:50,734 - onnx-tf - INFO - Start converting onnx pb to tf saved model
2023-12-10 10:51:50.905293: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.920277: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.920398: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.922045: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.922130: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.922184: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.950698: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.950787: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.950852: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 10:51:50.950904: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8089 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6

The name of the input tensor seems to be correct and standard:

In [91]:
import onnx

model = onnx.load("best_model.onnx")

print([input.name for input in model.graph.input])
['input.1']
In [92]:
import tensorflow as tf


def quantize_and_convert_to_tflite(tf_model_path: str, tflite_model_path: str):
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_quant_model = converter.convert()

    with open(tflite_model_path, "wb") as f:
        f.write(tflite_quant_model)

It seems that there are some incompatibilities to convert our model to tensorflow.

Unfortunately, we ran out of time for this project and cannot pursue this further, but it was fun to try nonetheless.

Summary¶

In our project, we tuned a Convolutional Neural Network (CNN) for mushroom classification using PyTorch Lightning. Our goal was to accurately classify different species of mushrooms based on images, a task with significant implications in fields such as mycology and food safety.

We began by fetching a diverse dataset of mushroom images from kaggle, ensuring a wide representation of different species. We then preprocessed the images and used a data loader for efficient input into our model. To address class imbalance, we implemented a rebalancing strategy, ensuring each class had an equal chance of being represented during training.

Our CNN model was designed to use transfer learning in order to leverage an existing image classification model (resnet50). We utilized various optimization techniques, including adjusting batch sizes, number of workers in the loader, and learning rates using optuna for hyperparameter tuning. We also employed PyTorch Lightning’s advanced features like GPU acceleration and automatic differentiation, which significantly streamlined our training/tuning process.

Throughout the project, we learned valuable lessons. We found that balancing the dataset improved model performance significantly. We also learned to monitor GPU memory usage closely, as it directly impacted the training speed. Furthermore, we discovered the importance of fine-tuning hyperparameters, such as the learning rate and batch size, to optimize model performance.

We also explored and compared performance between CPU/GPU runs, and even attempted to converting the model to tensorflow to try to run it on a usb TPU. This did not work but it was interesting to learn about ONNX and how to translate models from one framework to the others.

Executive Summary¶

Executive Summary

Our CNN mushroom classification model has shown promising results. The key metrics are as follows:

  • Accuracy: 0.694
  • Precision: 0.713
  • Recall: 0.694
  • F1 Score: 0.692

These metrics indicate the model’s ability to classify mushroom species accurately.

Overall, performing reasonably well.

Future work¶

Should we want to focus on keeping people safe (instead of focusing on identifying the right mushroom type), we could also convert this classification task from multiclass to binary into "safe/unsafe" and we would likely get more meaningful results, but this would put this project into a more difficult terrain as it would take on more responsabilities. Right now the decision falls on the user, and so does the responsability to decide for themselves, which is what feels better/safer given the circumstances.